package com.lw.leetcode.binary.c;

/**
 * Created with IntelliJ IDEA.
 * arr
 * c
 * LeetCode 4. Median of Two Sorted Arrays
 *
 * @author liw
 * @version 1.0
 * @date 2021/8/20 13:29
 */
public class FindMedianSortedArrays {

    /**
     * Small demo: prints the median of one pair of sample arrays to standard output.
     *
     * <p>Reference cases (inputs and their medians):
     * <pre>
     *   {1, 2, 3} / {2, 5, 8}    -> 2.5
     *   {1, 2, 3} / {4, 5, 8}    -> 3.5
     *   {}        / {4, 5, 6, 10} -> 5.5
     *   {1, 3}    / {2}          -> 2
     *   {1, 2}    / {3, 4}       -> 2.5
     *   {1, 1}    / {1, 1}       -> 1
     *   {}        / {1}          -> 1
     *   {2}       / {}           -> 2
     *   {1, 2, 2} / {1, 2, 3}    -> 2
     *   {1, 2}    / {-1, 3}      -> 1.5
     * </pre>
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        FindMedianSortedArrays test = new FindMedianSortedArrays();

        // Expected median: 2
        int[] arr1 = {1, 2, 2};
        int[] arr2 = {1, 2, 3};

        System.out.println(test.findMedianSortedArrays(arr1, arr2));
    }

    /**
     * Returns the median of the combined contents of two ascending-sorted arrays.
     *
     * <p>For an odd combined length the single middle element is returned; for an even
     * combined length the mean of the two middle elements is returned. The arrays are
     * only read, never modified or retained.
     *
     * @param nums1 first array, sorted in ascending order (may be empty)
     * @param nums2 second array, sorted in ascending order (may be empty)
     * @return the median of all elements of both arrays
     * @throws ArrayIndexOutOfBoundsException if both arrays are empty
     * @throws NullPointerException if either array is {@code null}
     */
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        int total = nums1.length + nums2.length;
        int half = total >> 1;

        // The (half + 1)-th smallest element is the median for odd totals and the
        // upper of the two middle elements for even totals.
        int upper = kthSmallest(nums1, nums2, half + 1);
        if ((total & 1) == 1) {
            return upper;
        }

        int lower = kthSmallest(nums1, nums2, half);
        // Widen to double BEFORE adding: an int sum of two large values would wrap around.
        return ((double) lower + upper) / 2.0;
    }

    /**
     * Finds the k-th smallest element (1-based) of the union of two ascending-sorted arrays.
     *
     * <p>Each iteration compares the midpoints of the two remaining ranges and discards a
     * part of one or both ranges that cannot contain the answer, adjusting {@code k} by
     * the number of smaller elements dropped.
     *
     * @param first  first sorted array
     * @param second second sorted array
     * @param k      1-based rank, expected to be between 1 and the combined length
     * @return the element of rank {@code k} in the merged order
     */
    private static int kthSmallest(int[] first, int[] second, int k) {
        int lo1 = 0;
        int hi1 = first.length - 1;
        int lo2 = 0;
        int hi2 = second.length - 1;

        while (true) {
            // One range exhausted: the answer is a direct index into the other one.
            if (lo1 > hi1) {
                return second[lo2 + k - 1];
            }
            if (lo2 > hi2) {
                return first[lo1 + k - 1];
            }
            if (k == 1) {
                return Math.min(first[lo1], second[lo2]);
            }

            int mid1 = lo1 + ((hi1 - lo1) >> 1);
            int mid2 = lo2 + ((hi2 - lo2) >> 1);
            // Element counts of the left halves, each including its midpoint.
            int left1 = mid1 - lo1 + 1;
            int left2 = mid2 - lo2 + 1;
            int leftCount = left1 + left2;
            int v1 = first[mid1];
            int v2 = second[mid2];

            if (v1 == v2) {
                // Both midpoints hold the same value, which therefore occupies
                // ranks leftCount - 1 and leftCount of the merged order.
                if (k == leftCount || k == leftCount - 1) {
                    return v1;
                }
                if (k < leftCount) {
                    hi1 = mid1 - 1;
                    hi2 = mid2 - 1;
                } else {
                    lo1 = mid1 + 1;
                    lo2 = mid2 + 1;
                    k -= leftCount;
                }
            } else if (v1 > v2) {
                if (k < leftCount) {
                    // first[mid1..hi1] all have rank >= leftCount, too large for k.
                    hi1 = mid1 - 1;
                } else {
                    if (k == leftCount) {
                        // Elements right of mid1 have rank > leftCount; drop them too.
                        hi1 = mid1;
                    }
                    // second[lo2..mid2] all have rank < leftCount <= k: skip them.
                    lo2 = mid2 + 1;
                    k -= left2;
                }
            } else {
                if (k < leftCount) {
                    // second[mid2..hi2] all have rank >= leftCount, too large for k.
                    hi2 = mid2 - 1;
                } else {
                    if (k == leftCount) {
                        // Elements right of mid2 have rank > leftCount; drop them too.
                        hi2 = mid2;
                    }
                    // first[lo1..mid1] all have rank < leftCount <= k: skip them.
                    lo1 = mid1 + 1;
                    k -= left1;
                }
            }
        }
    }

}
